import argparse
import os

from dgl.data.utils import save_graphs, save_info

from gnn.hetgnn.utils import load_data


def preprocess(args):
    """Load the random-walk neighbor graph and input features, then persist them.

    Args:
        args: Parsed command-line namespace. The attributes read are
            ``neighbor_path``, ``pretrained_node_embed_path`` (both passed to
            ``load_data``) and ``save_path`` (the output directory).

    Side effects:
        Writes ``neighbor_graph.bin`` via DGL's ``save_graphs`` and
        ``in_feats.pkl`` via DGL's ``save_info`` inside ``args.save_path``.
        Then prints a confirmation line to stdout.
    """
    graph, in_feats = load_data(args.neighbor_path, args.pretrained_node_embed_path)

    out_dir = args.save_path
    graph_file = os.path.join(out_dir, 'neighbor_graph.bin')
    feats_file = os.path.join(out_dir, 'in_feats.pkl')

    # save_graphs expects a list of graphs; only the single neighbor graph is stored.
    save_graphs(graph_file, [graph])
    save_info(feats_file, in_feats)

    print('Neighbor graph and input features saved to', out_dir)


def main():
    """Command-line entry point: parse the three positional paths and preprocess.

    Positional arguments, in order:
        neighbor_path, pretrained_node_embed_path, save_path
    """
    cli = argparse.ArgumentParser(description='HetGNN preprocessing')

    # (dest, help) pairs, registered in the order they must appear on the CLI.
    positional_args = (
        ('neighbor_path', 'path to neighbor file generated by random walk'),
        ('pretrained_node_embed_path', 'path to pretrained node embeddings'),
        ('save_path', 'path to save preprocessed neighbor graph and input features'),
    )
    for dest, help_text in positional_args:
        cli.add_argument(dest, help=help_text)

    preprocess(cli.parse_args())


if __name__ == '__main__':
    # Run the CLI only when executed as a script, not when this module is imported.
    main()
